import argparse
import jsonlines

import time
import pickle

from torch.utils.data import TensorDataset
import torch

import numpy as np
import os
import tqdm

def get_argument_parser():
    """Build the command-line parser for training / evaluation runs.

    Returns:
        argparse.ArgumentParser: parser defining every option below. Only
        ``--bert_model`` and ``--output_dir`` are required.

    Note:
        argparse applies ``%``-formatting to help strings (so placeholders
        such as ``%(default)s`` work). A literal percent sign in a help
        string must therefore be written as ``%%``; a bare ``%`` makes
        ``--help`` crash.
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    # NOTE(review): because required=True, the "roberta-large" default is never used.
    parser.add_argument("--bert_model", default="roberta-large", type=str, required=True,
                        help="Pre-trained model name, e.g. roberta-large, bert-base-uncased, "
                        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model checkpoints and predictions will be written.")

    # Other parameters
    parser.add_argument("--train_file", default=None, type=str, help="jsonl for training. E.g., train-v1.1.json")
    parser.add_argument("--val_file", default=None, type=str, help="dev-v1.1.json or test-v1.1.json")
    parser.add_argument("--max_seq_length", default=384, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    parser.add_argument("--max_query_length", default=64, type=int,
                        help="The maximum number of tokens for the question. Questions longer than this will "
                             "be truncated to this length.")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_predict", action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.")
    parser.add_argument("--predict_batch_size", default=32, type=int, help="Total batch size for predictions.")
    parser.add_argument("--learning_rate", default=1e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--epochs", default=6, type=int,
                        help="Total number of training epochs to perform.")
    # "%%" is required: argparse %-formats help text, and a bare "%" breaks --help.
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
                             "of training.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--wall_clock_breakdown',
                        action='store_true',
                        default=False,
                        help="Whether to display the breakdown of the wall-clock time for forward, backward and step")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--model_file",
                        type=str,
                        default="0",
                        help="Path to the Pretrained BERT Encoder File.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--max_grad_norm",
                        default=1.,
                        type=float,
                        help="Gradient clipping for FusedAdam.")
    parser.add_argument('--loss_plot_alpha',
                        type=float,
                        default=0.2,
                        help='Alpha factor for plotting moving average of loss.')

    parser.add_argument('--max_steps',
                        type=int,
                        default=100000,
                        help='Maximum number of training steps of effective batch size to complete.')

    parser.add_argument('--max_steps_per_epoch',
                        type=int,
                        default=100000,
                        help='Maximum number of training steps of effective batch size within an epoch to complete.')

    parser.add_argument('--print_steps',
                        type=int,
                        default=100,
                        help='Interval to print training details.')

    parser.add_argument('--model_type',
                        type=str,
                        default="bert-large-uncased",
                        help="Type of BERT model")

    parser.add_argument(
        '--load_checkpoint',
        type=str,
        default=None,
        help='directory to load check pointed models')

    parser.add_argument(
        '--gpu', type=str, default='0', help='gpu device number')

    parser.add_argument(
        '--gpus',
        type=int,
        default=4,
        help='how many gpus'
    )
    parser.add_argument(
        '--distributed_backend',
        type=str,
        default='dp',
        help='supports three options dp, ddp, ddp2'
    )
    parser.add_argument('--overwrite_cache',
                        default=False,
                        action='store_true',
                        help="Overwrite the cached file")

    parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                        help='evaluate model on validation set')

    parser.add_argument('--tiny', dest='tiny', action='store_true',
                        help='train on tiny')

    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight decay to apply, if any.")

    parser.add_argument(
        '--warm_up_steps', type=int, default=1, help='Warm up steps.')

    parser.add_argument(
        '--ktltype', type=str, default='mlm', help='ktl type')

    parser.add_argument(
        '--valtype', type=str, default='anli', help='validation dataset type')

    parser.add_argument('--use_context',
                        default=False,
                        action='store_true',
                        help="Answer embeds uses context")

    parser.add_argument('--use_question',
                        default=False,
                        action='store_true',
                        help="Context embeds uses question")

    parser.add_argument('--freeze_bert',
                        default=False,
                        action='store_true',
                        help="Freezes BERT weights")

    parser.add_argument('--resume',
                        default=False,
                        action='store_true',
                        help="If resume training from loaded checkpoint")

    parser.add_argument(
        '--mlm_maxlen',
        type=int,
        default=70,
        help='mlm max length'
    )

    parser.add_argument(
        '--maxc',
        type=int,
        default=30,
        help='Maximum context length'
    )

    parser.add_argument(
        '--maxq',
        type=int,
        default=30,
        help='Maximum question length'
    )

    parser.add_argument(
        '--maxa',
        type=int,
        default=30,
        help='Maximum answer length'
    )

    parser.add_argument(
        '--losstype', type=str, default='l2', help='Loss type')

    return parser


def load_cache(path_to_docmap):
    """Unpickle and return the object stored at ``path_to_docmap``.

    The file is opened in binary mode inside a ``with`` block, so it is
    closed even when unpickling fails.

    Args:
        path_to_docmap: path to a file previously written with ``pickle``
            (e.g. by ``save_cache``).

    Returns:
        The deserialized object.
    """
    # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the
    # file; only load cache files produced by this project / trusted sources.
    with open(path_to_docmap, "rb") as pickle_in:
        return pickle.load(pickle_in)

def save_cache(obj, fname):
    """Serialize ``obj`` with pickle into the file ``fname``.

    The file is opened with mode ``'wb+'`` (created if missing, truncated
    otherwise), and the highest pickle protocol available is used.
    """
    protocol = pickle.HIGHEST_PROTOCOL
    with open(fname, 'wb+') as out_file:
        pickle.dump(obj, out_file, protocol=protocol)
        
def simple_accuracy(preds, labels):
    """Return the fraction of predictions that equal their labels.

    Both inputs are converted with ``np.array``. ``labels`` is squeezed
    (all size-1 axes removed), so e.g. a label column of shape (N, 1)
    compares element-wise against N flat predictions.
    """
    pred_arr = np.array(preds)
    label_arr = np.array(labels).squeeze()
    hits = pred_arr == label_arr
    return hits.mean()

def npwhere(values, searchval):
    """Return the first-axis indices of elements where ``values == searchval``.

    For a 1-D array these are simply the matching positions. For N-D input
    only the first-axis coordinate of each match is returned (one entry per
    match).
    """
    mask = values == searchval
    return np.nonzero(mask)[0]

def create_tensordataset(hparams,typet,dataset):
    """Materialize ``dataset`` into an in-memory ``TensorDataset``.

    Each item of ``dataset`` is indexed positionally. Its fields are collected
    per name and finally combined with ``torch.stack``, so each field must be
    a tensor whose shape is identical across items.

    Item layout is chosen per item by ``len(data)``:

    * fewer than 12 fields -- ``(ctxt, cmsk, qsn, qmsk, atxt, amsk, label,
      lmask)``. The returned dataset holds only ``(ctxt, qsn, atxt, label,
      lmask)``; the three mask fields are collected but not returned.
    * 12 or more fields -- ``(ctxt, cmsk, qsn, qmsk, atxt, amsk, nctxt,
      ncmask, nqtxt, nqmask, natxt, namask)``, all returned in that order.
      Fields beyond index 11 are ignored.

    Args:
        hparams: parsed arguments. Only ``hparams.tiny`` is read; when true,
            loading stops after the first 202 items.
        typet: currently unused (caching is not implemented); kept for
            interface compatibility.
        dataset: iterable of indexable items as described above.

    Returns:
        torch.utils.data.TensorDataset

    Raises:
        ValueError: if ``dataset`` yields no items.
        IndexError: if an item is shorter than its layout requires.
    """
    # Field names in the positional order they appear inside each item.
    shared_keys = ("ctxt", "cmsk", "qsn", "qmsk", "atxt", "amsk")
    label_keys = ("label", "lmask")  # used by items with < 12 fields
    contrast_keys = ("nctxt", "ncmask", "nqtxt", "nqmask", "natxt", "namask")  # items with >= 12 fields
    features = {key: [] for key in shared_keys + label_keys + contrast_keys}

    for ix, data in tqdm.tqdm(enumerate(dataset), "Loading Dataset"):
        layout = shared_keys + (label_keys if len(data) < 12 else contrast_keys)
        # Explicit indexing (rather than zip) so a too-short item still raises IndexError.
        for pos, key in enumerate(layout):
            features[key].append(data[pos])
        if hparams.tiny and ix > 200:
            break

    # Caching (save_cache/load_cache keyed on typet etc.) is not implemented here.

    if not features["ctxt"]:
        # Without this, TensorDataset would fail with an opaque error on plain lists.
        raise ValueError("create_tensordataset: dataset yielded no items")

    for k, v in features.items():
        if len(v) == 0:
            continue
        features[k] = torch.stack(v)
        print(f"Shape of:{k}:{features[k].shape}", flush=True)

    if len(features["namask"]) == 0:
        return TensorDataset(features["ctxt"], features["qsn"], features["atxt"],
                             features["label"], features["lmask"])
    return TensorDataset(*(features[key] for key in shared_keys + contrast_keys))


def create_valdataset(hparams,typet,dataset):
    """Materialize a validation ``dataset`` into an in-memory ``TensorDataset``.

    Fields are grouped by position (``item[0]``, ``item[1]``, ...) and each
    group is combined with ``torch.stack``, so each field must be a tensor
    whose shape is identical across items. At most 7 fields per item are
    supported; a longer item raises ``KeyError``.

    If position 6 was never filled, only positions 0-4 are returned; note that
    position 5, if present, is then silently dropped. Presumably items are
    always either 5 or 7 fields long -- TODO confirm with the dataset class.
    Otherwise all 7 positions are returned in order.

    Args:
        hparams: parsed arguments. Only ``hparams.tiny`` is read; when true,
            loading stops after the first 102 items.
        typet: currently unused (caching is not implemented); kept for
            interface compatibility.
        dataset: iterable of indexable items.

    Returns:
        torch.utils.data.TensorDataset

    Raises:
        ValueError: if ``dataset`` yields no items.
    """
    features = {pos: [] for pos in range(7)}
    for ix, data in tqdm.tqdm(enumerate(dataset), "Loading Dataset"):
        for pos, field in enumerate(data):
            features[pos].append(field)  # KeyError if an item has more than 7 fields
        if hparams.tiny and ix > 100:
            break

    # Caching (save_cache/load_cache keyed on typet etc.) is not implemented here.

    if not features[0]:
        # Without this, TensorDataset would fail with an opaque error on plain lists.
        raise ValueError("create_valdataset: dataset yielded no items")

    for k, v in features.items():
        if len(v) == 0:
            continue
        features[k] = torch.stack(v)
        print(f"Shape of:{k}:{features[k].shape}", flush=True)

    if len(features[6]) == 0:
        return TensorDataset(features[0], features[1], features[2], features[3], features[4])
    return TensorDataset(*(features[pos] for pos in range(7)))